# -*- coding:utf-8 -*-
"""
coco tools
Written by Sanjun
remap the category ids in the original COCO json files to the new ids defined in label_name
"""

from __future__ import print_function
import os, sys
import shutil
import numpy as np
import json
from random import shuffle
from collections import OrderedDict

#########################################################
# coco detailed label information
#{
#    "info": info, # dict
#    "licenses": [license], # list, inner is dict
#    "images": [image], # list, inner is dict
#    "annotations": [annotation], # list, inner is dict
#    "categories": [category], # list, inner is dict
#}
#########################################################

#'./annotations/instances_val2017.json' # Object Instance label info
# person_keypoints_val2017.json  # Object Keypoint label info
# captions_val2017.json  # Image Caption label info

# Target categories written to the output json: new_id -> new_name.
label_name = OrderedDict([
    (1, "NewCat1"),
    (2, "NewCat2"),
    (3, "NewCat3"),
])

# Source categories: org_id -> [org_name, new_name].
# Fill this in from the *_cat_map.txt file generated by convert2coco.py.
# org_name is the category's name in the input json; new_name selects the
# target category and must be one of the names listed in label_name above.
org_labeled_name = OrderedDict([
    (1, ["Cat1", "NewCat1"]),
    (2, ["Cat2", "NewCat2"]),
    (5, ["Cat5", "NewCat3"]),
])

# get the category_name by category_id
def get_name_byID(in_json_file, in_id):
    """Return the name of the category whose id is ``in_id``.

    Args:
        in_json_file: path to a COCO-style json file with a ``categories`` list.
        in_id: category id to look up; compared with ``int(cat['id'])``.

    Returns:
        The name of the first matching category, as ``str``.

    Raises:
        RuntimeError: if no category has id ``in_id``, or if the matching
            category's name is empty or whitespace-only.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(in_json_file, 'r') as f:
        data = json.load(f)
    for cat in data['categories']:
        if in_id == int(cat['id']):
            cat_name = str(cat['name'])
            # ''.isspace() is False, so test strip() to reject empty names too.
            if not cat_name.strip():
                raise RuntimeError('{} has no name'.format(cat['id']))
            return cat_name
    # Previously this path hit an unbound local; report it explicitly.
    raise RuntimeError('{} not found in {}'.format(in_id, in_json_file))

# get the dict[category_id name bbox_num]
def get_categories_id_map(in_json_file, out_dir):
    """Count annotations per category and write a ``*_cat_map_new.txt`` table.

    The output file is ``<json basename without extension>_cat_map_new.txt``
    inside ``out_dir``. Each line holds ``id name count`` in fixed-width
    columns, sorted by category id. Only categories that have at least one
    annotation are listed.

    Args:
        in_json_file: path to a COCO-style json with ``categories`` and
            ``annotations`` lists.
        out_dir: existing directory to write the table into.

    Returns:
        OrderedDict mapping category id -> [category name, annotation count],
        sorted by id (the same data written to disk).

    Raises:
        RuntimeError: if an annotation references a category id that is not
            in ``categories``, or whose name is empty/whitespace-only.
    """
    with open(in_json_file, 'r') as f:
        data = json.load(f)

    # Index categories once (first occurrence wins for a duplicated id)
    # instead of re-reading and re-parsing the json for every new id.
    raw_names = {}
    for cat in data['categories']:
        raw_names.setdefault(int(cat['id']), cat['name'])

    counts = {}
    for ann in data['annotations']:
        cat_id = int(ann['category_id'])
        counts[cat_id] = counts.get(cat_id, 0) + 1

    result_record = OrderedDict()
    for cat_id in sorted(counts):
        if cat_id not in raw_names:
            raise RuntimeError('{} not found in {}'.format(cat_id, in_json_file))
        cat_name = str(raw_names[cat_id])
        if not cat_name.strip():
            raise RuntimeError('{} has no name'.format(cat_id))
        result_record[cat_id] = [cat_name, counts[cat_id]]

    # splitext keeps dotted stems intact (e.g. "a.v2.json" -> "a.v2").
    stem = os.path.splitext(os.path.basename(in_json_file))[0]
    file_name = stem + '_cat_map_new.txt'
    with open(os.path.join(out_dir, file_name), 'w') as f:
        for cat_id, (cat_name, num) in result_record.items():
            f.write('{:5} {:30} {:10}\n'.format(cat_id, cat_name, num))
    print('save key map', result_record)
    return result_record

def get_key(dict, value):
    """Return every key of the mapping ``dict`` whose value equals ``value``.

    Keys are returned in the mapping's iteration order; the list is empty
    when nothing matches. (The parameter keeps its original name ``dict``
    for caller compatibility even though it shadows the builtin.)
    """
    matches = []
    for key, candidate in dict.items():
        if candidate == value:
            matches.append(key)
    return matches

def get_total_categories(label):
    """Build a COCO ``categories`` list from an ``id -> name`` mapping.

    Each entry is ``{'supercategory': name, 'id': id, 'name': name}``; the
    category name doubles as its own supercategory. Entries follow the
    iteration order of ``label``.
    """
    return [
        {'supercategory': cat_name, 'id': cat_id, 'name': cat_name}
        for cat_id, cat_name in label.items()
    ]

def rename_category_json(in_json_dir, json_file, save_dir, save_ann_name,
                         label_map=None, org_map=None):
    """Rewrite the category ids of a COCO json file and save the result.

    For every annotation, the original ``category_id`` is looked up in
    ``org_map`` to find its target name, and that name is looked up in
    ``label_map`` to find the new id. The ``categories`` list is replaced by
    one built from ``label_map`` via ``get_total_categories``. All other
    top-level keys (``info``, ``licenses``, ``images``, ...) are written back
    unchanged, and are not required to be present.

    Args:
        in_json_dir: directory containing the input json.
        json_file: input json file name inside ``in_json_dir``.
        save_dir: directory to write the output json into.
        save_ann_name: output json file name inside ``save_dir``.
        label_map: new_id -> new_name mapping; defaults to ``label_name``.
        org_map: org_id -> [org_name, new_name] mapping; defaults to
            ``org_labeled_name``.

    Raises:
        KeyError: if an annotation uses a category id missing from ``org_map``.
        ValueError: if a target name from ``org_map`` is not a value of
            ``label_map``.
    """
    if label_map is None:
        label_map = label_name
    if org_map is None:
        org_map = org_labeled_name

    # new_name -> new_id, built once. For a duplicated name the first id
    # wins, matching the original get_key(label_map, name)[0] lookup.
    name_to_new_id = {}
    for new_id, new_name in label_map.items():
        name_to_new_id.setdefault(new_name, new_id)

    json_file_path = os.path.join(in_json_dir, json_file)
    with open(json_file_path, 'r') as f:
        data = json.load(f)

    for ann in data['annotations']:
        cat_id = ann['category_id']
        if cat_id not in org_map:
            raise KeyError('category id {} in {} has no entry in the original '
                           'label map'.format(cat_id, json_file_path))
        org_name, new_name = org_map[cat_id][0], org_map[cat_id][1]
        if new_name not in name_to_new_id:
            raise ValueError('"{}" (mapped from {}:{}) is not a name in the new '
                             'label map'.format(new_name, org_name, cat_id))
        # Annotations are updated in place; data['annotations'] keeps them.
        ann['category_id'] = name_to_new_id[new_name]

    data['categories'] = get_total_categories(label_map)

    # save; the with-block guarantees the file is flushed and closed before
    # returning, since __main__ reads it back immediately afterwards.
    save_ann_file = os.path.join(save_dir, save_ann_name)
    with open(save_ann_file, 'w') as f:
        json.dump(data, f, indent=4)  # indent=4 keeps the output readable
    print('save json done:', save_ann_file)

if __name__ == '__main__':
    # Usage: python <script> [in_path out_path]
    # Each path is a dataset root containing an "annotations" sub-directory.
    in_path = '/datasets/orgCOCODir'
    out_path = '/datasets/newCOCOCatIDDir'
    if len(sys.argv) == 3:
        in_path = sys.argv[1]
        out_path = sys.argv[2]
    elif len(sys.argv) != 1:
        # Any other argument count is a mistake; don't silently fall back to
        # the hard-coded default directories.
        sys.exit('usage: {} [in_path out_path]'.format(sys.argv[0]))
    print("in_path:", in_path)
    print("out_path:", out_path)
    in_annotation_path = os.path.join(in_path, 'annotations')
    out_annotation_path = os.path.join(out_path, 'annotations')
    if not os.path.exists(out_annotation_path):
        os.makedirs(out_annotation_path)

    for json_file in ('instances_val2018.json', 'instances_train2018.json'):
        save_file = json_file
        rename_category_json(in_annotation_path, json_file, out_annotation_path, save_file)
        # Summarize the freshly written (renamed) json, not the input one.
        renamed_json_file = os.path.join(out_annotation_path, save_file)
        get_categories_id_map(renamed_json_file, out_path)
